实操教程|PyTorch AutoGrad C++层实现 您所在的位置:网站首页 pytorch 20 实操 实操教程|PyTorch AutoGrad C++层实现

实操教程|PyTorch AutoGrad C++层实现

2023-08-24 06:21| 来源: 网络整理| 查看: 265

本文为一篇实操教程,作者介绍了PyTorch AutoGrad C++层实现中各个概念的解释。

autograd依赖的数据结构

at::Tensor:shared ptr 指向 TensorImpl

TensorImpl:对 at::Tensor 的实现

包含一个类型为 [AutogradMetaInterface](c10::AutogradMetaInterface) 的autograd_meta_,在tensor是需要求导的variable时,会被实例化为 [AutogradMeta](c10::AutogradMetaInterface) ,里面包含了autograd需要的信息

Variable: 就是Tensor,为了向前兼容保留的

using Variable = at::Tensor;

概念上有区别, Variable 是需要计算gradient的, Tensor 是不需要计算gradient的

Variable的 AutogradMeta是对 [AutogradMetaInterface](c10::AutogradMetaInterface)的实现,里面包含了一个 Variable,就是该variable的gradient

带有version和view

会实例化 AutogradMeta , autograd需要的关键信息都在这里

AutoGradMeta : 记录 Variable 的autograd历史信息

包含一个叫grad_的 Variable, 即 AutoGradMeta 对应的var的梯度tensor

包含类型为 Node 指针的 grad_fn (var在graph内部时)和 grad_accumulator(var时叶子时), 记录生成grad_的方法

包含 output_nr ,标识var对应 grad_fn的输入编号

构造函数包含一个类型为 Edge的gradient_edge, gradient_edge.function 就是 grad_fn, 另外 gradient_edge.input_nr 记录着对应 grad_fn的输入编号,会赋值给 AutoGradMeta 的 output_nr

autograd::Edge: 指向autograd::Node的一个输入

包含类型为 Node 指针,表示edge指向的Node

包含 input_nr, 表示edge指向的Node的输入编号

autograd::Node: 对应AutoGrad Graph中的Op

是所有autograd op的抽象基类,子类重载apply方法

next_edges_记录出边

input_metadata_记录输入的tensor的metadata

实现的子类一般是可求导的函数和他们的梯度计算op

Node in AutoGrad Graph

Variable通过Edge关联Node的输入和输出

多个Edge指向同一个Var时,默认做累加

call operator

最重要的方法,实现计算

next_edge

缝合Node的操作

获取Node的出边,next_edge(index)/next_edges()

add_next_edge(),创建

前向计算

PyTorch通过tracing只生成了后向AutoGrad Graph.

代码是生成的,需要编译才能看到对应的生成结果

gen_variable_type.py生成可导版本的op

生成的代码在 pytorch/torch/csrc/autograd/generated/

前向计算时,进行了tracing,记录了后向计算图构建需要的信息

这里以relu为例,代码在pytorch/torch/csrc/autograd/generated/VariableType_0.cpp

Tensor relu(const Tensor & self) { auto& self_ = unpack(self, "self", 0); std::shared_ptr grad_fn; if (compute_requires_grad( self )) { // 如果输入var需要grad // ReluBackward0的类型是Node grad_fn = std::shared_ptr(new ReluBackward0(), deleteNode); // collect_next_edges(var)返回输入var对应的指向的 // grad_fn(前一个op的backward或者是一个accumulator的)的输入的Edge // set_next_edges(),在grad_fn中记录这些Edge(这里完成了后向的构图) grad_fn->set_next_edges(collect_next_edges( self )); // 记录当前var的一个版本 grad_fn->self_ = SavedVariable(self, false); } #ifndef NDEBUG c10::optional self__storage_saved = self_.has_storage() ? c10::optional(self_.storage()) : c10::nullopt; c10::intrusive_ptr self__impl_saved; if (self_.defined()) self__impl_saved = self_.getIntrusivePtr(); #endif auto tmp = ([&]() { at::AutoNonVariableTypeMode non_var_type_mode(true); return at::relu(self_); // 前向计算 })(); auto result = std::move(tmp); #ifndef NDEBUG if (self__storage_saved.has_value()) AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage())); if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr()); #endif if (grad_fn) { // grad_fn增加一个输入,记录输出var的metadata作为grad_fn的输入 // 输出var的AutoGradMeta实例化,输出var的AutoGradMeta指向起grad_fn的输入 set_history(flatten_tensor_args( result ), grad_fn); } return result; }

 

可以看到和 grad_fn 相关的操作trace了一个op的计算,构建了后向计算图.

后向计算

autograd::backward():计算output var的梯度值,调用的run_backward()

autograd::grad() :计算有output var和到特定input的梯度值,调用的run_backward()

autograd::run_backward()

对于要求梯度的output var,获取其指向的grad_fn作为roots,是后向图的起点

对于有input var的,获取其指向的grad_fn作为output_edges, 是后向图的终点

调用 autograd::Engine::get_default_engine().execute(...) 执行后向计算

autograd::Engine::execute(...)

创建 GraphTask ,记录了一些配置信息

创建 GraphRoot ,是一个Node,把所有的roots作为其输出边,Node的apply()返回的是roots的grad【这里已经得到一个单起点的图】

计算依赖 compute_dependencies(...)

从GraphRoot开始,广度遍历,记录所有碰到的grad_fn的指针,并统计grad_fn被遇到的次数,这些信息记录到GraphTask中

GraphTask 初始化:当有input var时,判断后向图中哪些节点是真正需要计算的

GraphTask 执行

选择CPU or GPU线程执行

以CPU为例,调用的 autograd::Engine::thread_main(...)

autograd::Engine::thread_main(...)

evaluate_function(...) ,输入输出的处理,调度

call_function(...) , 调用对应的Node计算

执行后向过程中的生成的中间grad Tensor,如果不释放,可以用于计算高阶导数;(同构的后向图,之前的grad tensor是新的输出,grad_fn变成之前grad_fn的backward,这些新的输出还可以再backward)

具体的执行机制可以支撑单独开一个Topic分析,在这里讨论到后向图完成构建为止.



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有